{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load model and push to Hugging Face"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from hf_repo.config import SpeechLLMModelConfig\n",
    "from hf_repo.model import SpeechLLMModel\n",
    "\n",
    "# Instantiate the model from a default-constructed config; the trained\n",
    "# weights are loaded from a checkpoint in a later cell.\n",
    "conf = SpeechLLMModelConfig()\n",
    "model = SpeechLLMModel(conf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the trained checkpoint; adjust for your environment.\n",
    "CHECKPOINT_PATH = '/root/ml-research/multi-modal-llm/repo/paper_exp/checkpoints/pth/torch_model_best_checkpoints.pth'\n",
    "\n",
    "# map_location='cpu' lets the checkpoint load even when the device it was saved\n",
    "# from is not available. weights_only=True restricts unpickling to tensors and\n",
    "# plain containers, which is all a state dict needs.\n",
    "state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu', weights_only=True)\n",
    "\n",
    "# load_state_dict reports missing/unexpected keys; leaving it as the last\n",
    "# expression shows that report as the cell output.\n",
    "model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the custom classes with the transformers Auto* machinery before pushing.\n",
    "# \"AutoConfig\" is the default auto class for transformers configs; it is spelled out\n",
    "# here to mirror the explicit \"AutoModel\" registration below.\n",
    "conf.register_for_auto_class(\"AutoConfig\")\n",
    "model.register_for_auto_class(\"AutoModel\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target repository and commit message shared by both uploads.\n",
    "hub_repo_id = 'skit-ai/speechllm-1.5B'\n",
    "hub_commit_message = \"checkpoint update\"\n",
    "\n",
    "# Upload the config first, then the model; the model push stays the last\n",
    "# expression so its return value is shown as the cell output.\n",
    "conf.push_to_hub(hub_repo_id, commit_message=hub_commit_message)\n",
    "model.push_to_hub(hub_repo_id, commit_message=hub_commit_message)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run inference after push"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the published model directly from the Hub. trust_remote_code=True is\n",
    "# passed so transformers may use the custom model code shipped with the repo.\n",
    "from transformers import AutoModel\n",
    "\n",
    "model = AutoModel.from_pretrained(\n",
    "    \"skit-ai/speechllm-1.5B\",\n",
    "    trust_remote_code=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to the GPU, then switch it to evaluation mode for inference.\n",
    "model = model.to(\"cuda\")\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Time a single metadata-generation call that takes the audio file path directly.\n",
    "sample_audio_path = \"/root/ml-research/multi-modal-llm/datadir/data/LibriSpeech/dev-other/1255/90407/1255-90407-0004.flac\"\n",
    "model.generate_meta(\n",
    "    sample_audio_path,\n",
    "    instruction=\"Give me the [SpeechActivity, Transcript, Gender, Age, Accent, Emotion] of the audio.\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchaudio\n",
    "\n",
    "# Same request as the path-based call, but passing an already-decoded waveform\n",
    "# through the audio_tensor keyword instead of a file path.\n",
    "audio_tensor, rate = torchaudio.load(\"/root/ml-research/multi-modal-llm/datadir/data/LibriSpeech/dev-other/1255/90407/1255-90407-0004.flac\")\n",
    "# NOTE(review): `rate` is not forwarded to generate_meta -- presumably the model\n",
    "# assumes a fixed sample rate; confirm before feeding audio from other sources.\n",
    "model.generate_meta(\n",
    "    audio_tensor=audio_tensor,\n",
    "    instruction=\"Give me the [SpeechActivity, Transcript, Gender, Age, Accent, Emotion] of the audio.\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import IPython.display\n",
    "\n",
    "# Render an inline audio player for the sample clip used above.\n",
    "playback_path = \"/root/ml-research/multi-modal-llm/datadir/data/LibriSpeech/dev-other/1255/90407/1255-90407-0004.flac\"\n",
    "IPython.display.Audio(playback_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mm_trainer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
